-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat: add activation checkpointing to unet #8554
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add activation checkpointing to unet #8554
Conversation
WalkthroughAdds monai/networks/blocks/activation_checkpointing.py implementing ActivationCheckpointWrapper that applies torch.utils.checkpoint.checkpoint(..., use_reentrant=False) to a wrapped nn.Module. Adds CheckpointUNet(UNet) in monai/networks/nets/unet.py which overrides _get_connection_block to wrap the connection subblock, down_path, and up_path with ActivationCheckpointWrapper and updates Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Areas to pay attention to:
Pre-merge checks and finishing touches✅ Passed checks (3 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)
29-33: Add a brief class docstring to the wrapper.
Improves discoverability and meets docstring guidelines.Apply this diff:
class _ActivationCheckpointWrapper(nn.Module): - def __init__(self, module: nn.Module) -> None: + """Apply activation checkpointing to the wrapped module during training.""" + def __init__(self, module: nn.Module) -> None: super().__init__() self.module = module
134-135: Document the newuse_checkpointingarg in the class docstring and user docs.
State trade-offs (memory vs compute), that it’s training-only, incompatible withtorch.no_grad, and preserves RNG by default.Proposed docstring snippet to add under “Args”:
use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.I can open a docs patch and add a short example enabling the flag.
163-164: Static wrapping caveat: runtime flips won’t take effect.
After init, changingself.use_checkpointingwon’t rewrap existing blocks. Either document this or add a small helper to (re)build the model if you expect runtime toggling.Do you expect users to toggle this at runtime? If yes, I can sketch a safe rewrap helper.
210-212: Checkpointing scope is subblock-only; consider an optional broader policy.
Current placement is a good default. If more memory is needed, offer a policy to also wrapdown_path/up_path(with a warning about extra compute).I can add a
checkpoint_policy: Literal["subblock","all"] = "subblock"in__init__and wire it here on request.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py(5 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: packaging
🔇 Additional comments (1)
monai/networks/nets/unet.py (1)
16-21: Imports for checkpointing look good.
castandcheckpointare appropriate for the new wrapper.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Fábio S. Ferreira <ferreira.fabio80@gmail.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
29-43: Good guard + compatibility fallback.
Training/grad-enabled checks anduse_reentrant=FalsewithTypeErrorfallback are the right call. This addresses the prior review note.
🧹 Nitpick comments (5)
monai/networks/nets/unet.py (5)
29-43: Avoid per-iteration TypeError cost: detectuse_reentrantsupport once.
Resolve support at import/init time to prevent raising an exception every forward on older torch.Apply:
@@ -class _ActivationCheckpointWrapper(nn.Module): +_SUPPORTS_USE_REENTRANT: bool | None = None + +class _ActivationCheckpointWrapper(nn.Module): @@ - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.training and torch.is_grad_enabled() and x.requires_grad: - try: - return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) - except TypeError: - # Fallback for older PyTorch without `use_reentrant` - return cast(torch.Tensor, checkpoint(self.module, x)) - return cast(torch.Tensor, self.module(x)) + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training and torch.is_grad_enabled() and x.requires_grad: + global _SUPPORTS_USE_REENTRANT + if _SUPPORTS_USE_REENTRANT is None: + try: + # probe once + checkpoint(self.module, x, use_reentrant=False) # type: ignore[arg-type] + _SUPPORTS_USE_REENTRANT = True + except TypeError: + _SUPPORTS_USE_REENTRANT = False + except Exception: + # do not change behavior on unexpected errors; fall back below + _SUPPORTS_USE_REENTRANT = False + if _SUPPORTS_USE_REENTRANT: + return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + return cast(torch.Tensor, checkpoint(self.module, x)) + return cast(torch.Tensor, self.module(x))Add outside the hunk (file header):
import inspect # if you switch to signature probing instead of try/exceptNote: PyTorch recommends passing
use_reentrantexplicitly going forward. (docs.pytorch.org)
29-43: TorchScript: make wrapper script-safe.
try/exceptand dynamic checkpoint calls won’t script. Short-circuit under scripting.Apply:
@@ - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + # Avoid checkpoint in scripted graphs + return cast(torch.Tensor, self.module(x))
29-43: Docstring completeness.
Add Google-style docstrings for the wrapper’s class/init/forward (inputs, returns, raises).Example:
@@ -class _ActivationCheckpointWrapper(nn.Module): - """Apply activation checkpointing to the wrapped module during training.""" +class _ActivationCheckpointWrapper(nn.Module): + """Wrap a module and apply activation checkpointing during training. + + Args: + module: The submodule to checkpoint. + + Returns: + torch.Tensor: Output tensor from the wrapped submodule. + + Raises: + RuntimeError: If checkpoint fails at runtime. + """
90-92: Tighten theuse_checkpointingdocstring and add a BN caveat.
Keep it on one Args entry and note the BatchNorm limitation.- use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory - at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False. + use_checkpointing: If True, apply activation checkpointing to internal sub-blocks during training to reduce + memory at the cost of extra compute. Bypassed in eval and when gradients are disabled. Note: avoid with + BatchNorm layers due to running-stat updates during recomputation. Defaults to False.Reference on RNG determinism (dropout is handled by default). (docs.pytorch.org)
217-219: Wrap site is fine; consider BN detection here instead of inside the wrapper.
Keeps policy close to where wrapping occurs and avoids per-instance warnings.- if self.use_checkpointing: - subblock = _ActivationCheckpointWrapper(subblock) + if self.use_checkpointing: + has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in subblock.modules()) + if has_bn: + warnings.warn( + "Skipping activation checkpointing for this subblock (contains BatchNorm).", + RuntimeWarning, + ) + else: + subblock = _ActivationCheckpointWrapper(subblock) return nn.Sequential(down_path, SkipConnection(subblock), up_path)Rationale: avoids double-updating BN stats during recomputation. (github.com)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py(6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)
16-16: LGTM: imports are correct and scoped.
castandcheckpointare needed by the wrapper; no issues.Also applies to: 20-20
141-141: Public API addition: ensure tests and docs cover new flag.
Add unit tests for parity (on/off), eval bypass, andno_grad()bypass; document in release notes/configs.Proposed minimal tests:
- Forward/backward equivalence within tolerance for
use_checkpointing={False,True}on a tiny UNet.model.eval()andtorch.no_grad()paths skip checkpoint (nocheckpointmocks invoked).- Autocast path under CUDA runs without dtype mismatches.
Do you want a patch with pytest tests added under
tests/test_unet_checkpointing.py?
170-170: LGTM: stores flag on the instance.
No concerns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)
29-51: Warn once for BN, and add param/return docstring per guidelines.Avoid repeated RuntimeWarnings and document the wrapper’s contract.
class _ActivationCheckpointWrapper(nn.Module): - """Apply activation checkpointing to the wrapped module during training.""" + """Apply activation checkpointing to the wrapped module during training. + + Args: + module: submodule to wrap. + Returns: + torch.Tensor: output of the wrapped module. + Warnings: + Skips checkpointing and emits a RuntimeWarning if the submodule contains + BatchNorm to avoid double-updating running stats. + """ def __init__(self, module: nn.Module) -> None: super().__init__() # Pre-detect BatchNorm presence for fast path self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules()) self.module = module + self._bn_warned = False def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training and torch.is_grad_enabled() and x.requires_grad: if self._has_bn: - warnings.warn( - "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating " - "running statistics during recomputation.", - RuntimeWarning, - ) + if not self._bn_warned: + warnings.warn( + "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating " + "running statistics during recomputation.", + RuntimeWarning, + ) + self._bn_warned = True return cast(torch.Tensor, self.module(x)) try: return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) except TypeError: # Fallback for older PyTorch without `use_reentrant` return cast(torch.Tensor, checkpoint(self.module, x)) return cast(torch.Tensor, self.module(x))Minimal tests to add:
- Training vs eval parity (values match with/without checkpointing).
- BN subblock emits RuntimeWarning and bypasses checkpointing.
- Guard under
torch.no_grad()and whenrequires_grad=False.
99-101: Clarifyuse_checkpointingbehavior in docs (BN, grad, training-only, build-time).Make expectations explicit for users.
- use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory - at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False. + use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce + memory at the cost of extra compute. Checkpointing is active only when `self.training` is True, gradients + are enabled, and inputs require gradients; it is bypassed in eval and when grads are disabled. + Sub-blocks containing BatchNorm are not checkpointed to avoid double-updating running stats. + Note: this is a build-time option; changing it after initialization will not retroactively wrap existing + sub-blocks. Defaults to False.
179-180: Flag is build-time only; consider asserting or documenting.Changing
self.use_checkpointingpost-init has no effect since wrappers are created during construction. The doc update above covers this; alternatively, convert to a read-only attribute to avoid confusion.
226-228: Wrapping only the subblock is OK; consider optional coverage toggle.If desired later, expose an opt-in to also wrap
down/upblocks for additional savings.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py(6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
16-21: Imports look good.Scoped import of
checkpointpluscastis appropriate.
150-151: API change is sensible and non-breaking.Parameter added at the end; default preserves behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
29-43: Skip checkpointing for subblocks with BatchNorm to avoid double-updating running stats.
Checkpoint recompute updates BN running stats twice under training. Detect BN in the wrapped module and bypass checkpointing with a warning.Apply this diff:
class _ActivationCheckpointWrapper(nn.Module): - """Apply activation checkpointing to the wrapped module during training.""" + """Apply activation checkpointing to the wrapped module during training. + Skips checkpointing for submodules containing BatchNorm to avoid double-updating + running statistics during recomputation. + """ def __init__(self, module: nn.Module) -> None: super().__init__() self.module = module + self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules()) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.training and torch.is_grad_enabled() and x.requires_grad: + if self.training and torch.is_grad_enabled() and x.requires_grad: + if self._has_bn: + warnings.warn( + "Activation checkpointing skipped for a subblock containing BatchNorm " + "to avoid double-updating running statistics during recomputation.", + RuntimeWarning, + ) + return cast(torch.Tensor, self.module(x)) try: return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) except TypeError: # Fallback for older PyTorch without `use_reentrant` return cast(torch.Tensor, checkpoint(self.module, x)) return cast(torch.Tensor, self.module(x))
🧹 Nitpick comments (3)
monai/networks/nets/unet.py (3)
90-92: Clarify arg docs and surface BN caveat.
Tighten wording and document BN behavior for transparency.Apply this diff:
- use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory - at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False. + use_checkpointing: If True, applies activation checkpointing to internal sub-blocks during training to reduce + memory at the cost of extra compute. Bypassed in eval mode and when gradients are disabled. + Note: sub-blocks containing BatchNorm are executed without checkpointing to avoid double-updating + running statistics. Defaults to False.
217-219: Placement of wrapper is sensible; consider optional breadth control.
Future enhancement: expose a knob to checkpoint down/up paths too for deeper memory savings on very deep nets.
141-142: Add tests to lock behavior.
- Parity: forward/backward equivalence (outputs/grad norms) with vs. without checkpointing.
- Modes: train vs. eval; torch.no_grad().
- Norms: with InstanceNorm and with BatchNorm (assert BN path skips with warning).
I can draft unit tests targeting UNet’s smallest config to keep runtime minimal—want me to open a follow-up?
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py(6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)
16-21: LGTM: imports for cast/checkpoint are correct.
Direct import of checkpoint and use of typing.cast are appropriate.
35-42: Validate AMP behavior under fallback (reentrant) checkpointing.
Older Torch (fallback path) may not replay autocast exactly; please verify mixed-precision parity.Minimal check: run a forward/backward with torch.autocast and compare loss/grad norms with and without checkpointing on a small UNet to ensure deltas are within numerical noise.
141-142: API addition looks good.
Name and default match MONAI conventions.
|
Hi @ferreirafabio80 thanks for the contribution but I would suggest this isn't necessarily the way to go with adapting this class. Perhaps instead you can create a subclass of class CheckpointUNet(UNet):
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
subblock = _ActivationCheckpointWrapper(subblock)
return super()._get_connection_block(down_path, up_path, subblock)This would suffice for your own use if you just wanted such a definition. I think the I see also that |
|
Hi @ericspod, thank you for your comments. Yes, that also works. I've defined a subclass and overridden the method as you suggested. Regarding the I was probably being extremely careful with the checks in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
♻️ Duplicate comments (2)
monai/networks/nets/unet.py (2)
35-36: Missing training and gradient guards causes eval overhead and no_grad crashes.The forward unconditionally calls checkpoint. This will:
- Apply checkpointing during inference (eval mode) → unnecessary compute overhead.
- Fail under
torch.no_grad()→ runtime error.Apply this diff:
def forward(self, x: torch.Tensor) -> torch.Tensor: - return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + if self.training and torch.is_grad_enabled() and x.requires_grad: + return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + return cast(torch.Tensor, self.module(x))
29-37: BatchNorm in checkpointed subblocks will double-update running stats.Checkpoint recomputes the forward pass during backward, causing BatchNorm layers to update
running_mean/running_vartwice per training step, skewing statistics.Consider detecting BatchNorm in
__init__and either warning or skipping checkpoint:class _ActivationCheckpointWrapper(nn.Module): """Apply activation checkpointing to the wrapped module during training.""" def __init__(self, module: nn.Module) -> None: super().__init__() self.module = module + self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules()) def forward(self, x: torch.Tensor) -> torch.Tensor: - return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + if self.training and torch.is_grad_enabled() and x.requires_grad: + if self._has_bn: + warnings.warn( + "Activation checkpointing skipped for subblock with BatchNorm to avoid double-update of running stats.", + RuntimeWarning, + ) + return cast(torch.Tensor, self.module(x)) + return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + return cast(torch.Tensor, self.module(x))
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py(4 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
16-16: LGTM on imports.Both
castandcheckpointare used in the new wrapper and are correctly imported.Also applies to: 20-20
316-316: Clarify checkpointing scope: onlysubblockvs. entire connection block.Only
subblock(the recursive nested structure) is wrapped, whiledown_pathandup_path(encoder/decoder convolutions at each level) are not checkpointed. Is this intentional?Typical UNet checkpointing strategies checkpoint entire encoder/decoder blocks for maximum memory savings. Consider whether
down_pathandup_pathshould also be wrapped, or document the rationale for checkpointing only the recursive substructure.# Alternative: checkpoint all three components def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: down_path = _ActivationCheckpointWrapper(down_path) up_path = _ActivationCheckpointWrapper(up_path) subblock = _ActivationCheckpointWrapper(subblock) return super()._get_connection_block(down_path, up_path, subblock)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
♻️ Duplicate comments (2)
monai/networks/nets/unet.py (2)
26-26: ExportCheckpointUNetvia__all__.Public class isn’t exported; importing it from
monai.networks.netswill fail. Add it to__all__.-__all__ = ["UNet", "Unet"] +__all__ = ["UNet", "Unet", "CheckpointUNet"]
36-37: Guard checkpoint under grad-enabled training only.Forward unconditionally checkpointing will crash under
torch.no_grad()(common in eval) because checkpoint requires gradients. Gate onself.training,torch.is_grad_enabled(), and a grad-requiring input before calling checkpoint.def forward(self, x: torch.Tensor) -> torch.Tensor: - return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + if self.training and torch.is_grad_enabled() and x.requires_grad: + return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + return cast(torch.Tensor, self.module(x))
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py(3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: build-docs
Sorry for the delay, I think we should put this into a new file in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/nets/unet.py (1)
25-25: Export the new public variant.
ExposeCheckpointUNetalongsideUNet/Unetsofrom monai.networks.nets import CheckpointUNetworks consistently.-__all__ = ["UNet", "Unet"] +__all__ = ["UNet", "Unet", "CheckpointUNet"]
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
302-307: Docstring required forCheckpointUNet.
Add a Google-style class docstring describing the checkpointing behavior, inherited args, and trade-offs per our guidelines.class CheckpointUNet(UNet): + """ + UNet variant that wraps internal connection blocks with activation checkpointing. + + See `UNet` for constructor arguments. During training with gradients enabled, + intermediate activations inside encoder–decoder connections are recomputed in + the backward pass to reduce peak memory usage at the cost of extra compute. + """
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/networks/blocks/activation_checkpointing.py(1 hunks)monai/networks/nets/unet.py(2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/blocks/activation_checkpointing.pymonai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: packaging
- GitHub Check: build-docs
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Nitpick comments (1)
monai/networks/nets/unet.py (1)
25-25: Sort__all__alphabetically.-__all__ = ["UNet", "Unet", "CheckpointUNet"] +__all__ = ["CheckpointUNet", "UNet", "Unet"]
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py(2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
🪛 Ruff (0.14.3)
monai/networks/nets/unet.py
25-25: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
306-306: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
🔇 Additional comments (1)
monai/networks/nets/unet.py (1)
20-20: LGTM.Import is clean and the wrapper is properly used in
CheckpointUNet.
|
@ericspod I've moved the wrapper to a different script, added docstrings and a test (which is literally a copy of the unet one). Let me know if this is sensible. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
monai/networks/nets/unet.py (1)
25-25: Consider sorting__all__alphabetically.Static analysis suggests sorting:
["CheckpointUNet", "UNet", "Unet"].-__all__ = ["UNet", "Unet", "CheckpointUNet"] +__all__ = ["CheckpointUNet", "UNet", "Unet"]
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/networks/nets/unet.py(2 hunks)tests/networks/nets/test_checkpointunet.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.pytests/networks/nets/test_checkpointunet.py
🪛 Ruff (0.14.3)
monai/networks/nets/unet.py
25-25: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: packaging
- GitHub Check: build-docs
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (ubuntu-latest)
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
20-20: LGTM!Import is clean and appropriately placed.
302-324: Implementation and docstrings are solid.The subclass correctly wraps all connection-block components before delegating to the parent. Docstrings follow Google style per coding guidelines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/networks/nets/test_checkpointunet.py (1)
167-204: Add tests verifying checkpointing activates during training.All tests run in eval mode (via
eval_modeortest_script_save), so checkpointing is never engaged. Must verify: (a) forward pass matches UNet in eval mode, and (b) checkpointing works during training with gradients.Add two tests as suggested in the past review:
def test_checkpoint_parity_eval(self): """Verify CheckpointUNet matches UNet output in eval mode.""" torch.manual_seed(0) from monai.networks.nets.unet import UNet config = {"spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (16, 32, 64), "strides": (2, 2), "num_res_units": 1} unet = UNet(**config).to(device) checkpoint_unet = CheckpointUNet(**config).to(device) checkpoint_unet.load_state_dict(unet.state_dict()) test_input = torch.randn(2, 1, 32, 32).to(device) with eval_mode(unet), eval_mode(checkpoint_unet): out_unet = unet(test_input) out_checkpoint = checkpoint_unet(test_input) self.assertTrue(torch.allclose(out_unet, out_checkpoint, atol=1e-6)) def test_checkpoint_engages_training(self): """Verify checkpointing activates during training.""" net = CheckpointUNet( spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1 ).to(device) net.train() test_input = torch.randn(2, 1, 32, 32, requires_grad=True, device=device) output = net(test_input) loss = output.sum() loss.backward() self.assertIsNotNone(test_input.grad)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
tests/networks/nets/test_checkpointunet.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/networks/nets/test_checkpointunet.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (3)
tests/networks/nets/test_checkpointunet.py (3)
81-86: Add Google-style docstring.Per coding guidelines, test methods require docstrings describing purpose, parameters, and expected behavior.
Apply this diff:
@parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): + """Verify CheckpointUNet produces expected output shapes. + + Args: + input_param: UNet constructor arguments. + input_shape: Input tensor dimensions. + expected_shape: Expected output tensor dimensions. + """ net = CheckpointUNet(**input_param).to(device)
88-93: Add Google-style docstring.Per coding guidelines, test methods require docstrings.
Apply this diff:
def test_script(self): + """Verify CheckpointUNet is scriptable via TorchScript.""" net = CheckpointUNet(
95-99: Add Google-style docstring.Per coding guidelines, test methods require docstrings.
Apply this diff:
def test_ill_input_shape(self): + """Verify RuntimeError raised for mismatched input shape.""" net = CheckpointUNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2))
🧹 Nitpick comments (1)
tests/networks/nets/test_checkpointunet.py (1)
119-139: Test logic is sound; minor redundancy in assertion.The test verifies gradient propagation during training, indirectly confirming checkpointing works. Line 139's
assertIsNotNone(grad_norm)is redundant sincegrad_normis always a tensor after line 134.Optionally remove the redundant assertion:
# gradient flow check grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None) self.assertGreater(grad_norm.item(), 0.0) # checkpointing should reduce activation memory use; we can't directly assert memory savings # but we can confirm no runtime errors and gradients propagate correctly - self.assertIsNotNone(grad_norm)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
tests/networks/nets/test_checkpointunet.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/networks/nets/test_checkpointunet.py
🔇 Additional comments (2)
tests/networks/nets/test_checkpointunet.py (2)
1-23: Imports and setup look correct.All necessary components are imported and device selection follows standard patterns.
25-77: Test case definitions provide good coverage.Cases cover 2D/3D variants, different channel counts, and residual unit configurations.
This looks much better, thanks. Please do work on the testing issues and the DCO issue can be left until last. |
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/networks/nets/test_checkpointunet.py (1)
125-186: Add Google-style docstrings for tests.
Docstrings are missing or not in Google style fortest_shape,test_script,test_checkpointing_equivalence_eval, andtest_checkpointing_activates_training. Please add compliant summaries and sections so the tests meet the documented requirement. As per coding guidelines@@ class TestCheckpointUNet(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): + """Validate CheckpointUNet output shapes across configurations. + + Args: + input_param: Mapping of constructor kwargs for the network under test. + input_shape: Shape tuple for the synthetic input tensor. + expected_shape: Expected output tensor shape. + """ net = CheckpointUNet(**input_param).to(device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) def test_script(self): - """ - TorchScript doesn't support activation-checkpointing (torch.utils.checkpoint) calls inside the module. - To keep the test suite validating TorchScript compatibility, script the plain UNet (which is scriptable), - rather than the CheckpointUNet wrapper that uses checkpointing internals. - """ + """Script the baseline UNet to maintain TorchScript coverage. + + Raises: + RuntimeError: If scripting the baseline UNet fails. + """ net = UNet( spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 ) test_data = torch.randn(16, 1, 32, 32) test_script_save(net, test_data) def test_checkpointing_equivalence_eval(self): - """Ensure that CheckpointUNet matches standard UNet in eval mode (checkpointing inactive).""" + """Confirm eval parity when checkpointing is inactive. + + Raises: + AssertionError: If shapes or outputs diverge beyond tolerances. + """ params = dict( spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1 ) @@ def test_checkpointing_activates_training(self): - """Ensure checkpointing triggers recomputation under training and gradients propagate.""" + """Verify checkpointing recomputes activations during training. + + Raises: + AssertionError: If gradients are not propagated through the model. + """ params = dict( spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1 )
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
tests/networks/nets/test_checkpointunet.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/networks/nets/test_checkpointunet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: packaging
- GitHub Check: build-docs
|
Thanks. I have fixed the testing issues. What is the easiest way to fix the DCO issue? |
Hi @ferreirafabio80, you can refer to the guide here: https://github.com/Project-MONAI/MONAI/pull/8554/checks?check_run_id=54816594479 |
Signed-off-by: Fabio Ferreira <f.ferreira@qureight.com>
I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: de2b6bd I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 66edcb5 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: feefcaa I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: f673ca1 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 69540ff I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 42ec757 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: a2e8474 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 4c4782e I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 515c659 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: da5a3a4 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 43dec88 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 84c0f48 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 5805515 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 1aa8e3c I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: 447d9f2 I, Fabio Ferreira <f.ferreira@qureight.com>, hereby add my Signed-off-by to this commit: b20a19e Signed-off-by: Fabio Ferreira <f.ferreira@qureight.com>
|
thank you @KumoLiu. Just fixed it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/networks/nets/test_checkpointunet.py (1)
126-183: Add Google-style docstrings with Args sections.Per coding guidelines, docstrings must describe each parameter. All test methods have brief one-liners but lack
Args:sections.Example for
test_shape:@parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): - """Validate CheckpointUNet output shapes across configurations.""" + """Validate CheckpointUNet output shapes across configurations. + + Args: + input_param: Dictionary of UNet constructor arguments. + input_shape: Tuple specifying input tensor dimensions. + expected_shape: Tuple specifying expected output tensor dimensions. + """ net = CheckpointUNet(**input_param).to(device)Apply similar updates to
test_script,test_checkpointing_equivalence_eval, andtest_checkpointing_activates_training.
🧹 Nitpick comments (2)
tests/networks/nets/test_checkpointunet.py (2)
149-153: Fragile weight synchronization; useload_state_dictinstead.Relying on
manual_seedbefore each instantiation is fragile—if initialization order differs betweenUNetandCheckpointUNet(e.g., wrapper layers init first), weights diverge. Useload_state_dictto ensure identical parameters.- torch.manual_seed(42) - net_plain = UNet(**params).to(device) - - torch.manual_seed(42) - net_ckpt = CheckpointUNet(**params).to(device) + torch.manual_seed(42) + net_plain = UNet(**params).to(device) + net_ckpt = CheckpointUNet(**params).to(device) + net_ckpt.load_state_dict(net_plain.state_dict())
164-165: Tighten tolerance for identical operations.A threshold of
1e-3is loose when both models have identical weights and run in eval mode. Use1e-6ortorch.allclosewith default tolerances.- diff = torch.mean(torch.abs(y_ckpt - y_plain)).item() - self.assertLess(diff, 1e-3, f"Eval-mode outputs differ more than expected (mean abs diff={diff:.6f})") + self.assertTrue( + torch.allclose(y_ckpt, y_plain, atol=1e-6, rtol=1e-5), + f"Eval-mode outputs differ: max abs diff={torch.max(torch.abs(y_ckpt - y_plain)).item():.6e}" + )
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
tests/networks/nets/test_checkpointunet.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/networks/nets/test_checkpointunet.py
Signed-off-by: Fabio Ferreira <f.ferreira@qureight.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/networks/nets/test_checkpointunet.py (1)
125-130: Consider adding parameter descriptions to docstring.The parameterized test has three parameters but the docstring doesn't describe them. Adding an Args section would improve clarity per coding guidelines.
- """Validate CheckpointUNet output shapes across configurations.""" + """Validate CheckpointUNet output shapes across configurations. + + Args: + input_param: Dictionary of UNet constructor arguments. + input_shape: Tuple specifying input tensor dimensions. + expected_shape: Tuple specifying expected output tensor dimensions. + """
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
tests/networks/nets/test_checkpointunet.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/networks/nets/test_checkpointunet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: packaging
- GitHub Check: build-docs
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (codeformat)
🔇 Additional comments (1)
tests/networks/nets/test_checkpointunet.py (1)
158-174: LGTM: Training gradient flow verified.The test confirms checkpointing doesn't break backward pass and gradients flow correctly. Direct verification of activation recomputation would require introspection, so this approach is pragmatic.
Signed-off-by: Fabio Ferreira <f.ferreira@qureight.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/networks/nets/test_checkpointunet.py (1)
125-130: Consider documenting parameterized test arguments.The docstring omits the three arguments provided by
@parameterized.expand. While the brief docstring is present, adding anArgs:section would clarify whatinput_param,input_shape, andexpected_shaperepresent for future maintainers.Example:
@parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): - """Validate CheckpointUNet output shapes across configurations.""" + """Validate CheckpointUNet output shapes across configurations. + + Args: + input_param: Dictionary of UNet constructor arguments. + input_shape: Tuple specifying input tensor dimensions. + expected_shape: Tuple specifying expected output tensor dimensions. + """ net = CheckpointUNet(**input_param).to(device)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
tests/networks/nets/test_checkpointunet.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/networks/nets/test_checkpointunet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: packaging
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.9)
🔇 Additional comments (1)
tests/networks/nets/test_checkpointunet.py (1)
160-176: Training test correctly verifies gradient flow.Logic is sound: the test confirms that backward pass completes through checkpointed blocks and produces non-zero gradients. The gradient aggregation on line 175 correctly sums across all parameters.
|
Hi @ferreirafabio80 it's looking pretty good here, if you can resolve the last conversation items with coderabbit (eg. update docstrings) and others we can merge this soon I think. |
Signed-off-by: Fabio Ferreira <f.ferreira@qureight.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/networks/nets/test_checkpointunet.py (1)
144-150: Set seed before creating input tensor.Input
xis created before seeds are set. While functionally correct (samexfeeds both networks), setting the seed first ensures full reproducibility.Apply this diff:
- x = torch.randn(2, 1, 32, 32, device=device) - torch.manual_seed(42) + x = torch.randn(2, 1, 32, 32, device=device) net_plain = UNet(**params).to(device) torch.manual_seed(42)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
tests/networks/nets/test_checkpointunet.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/networks/nets/test_checkpointunet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.7.1)
🔇 Additional comments (2)
tests/networks/nets/test_checkpointunet.py (2)
125-136: LGTM - shape validation is comprehensive.Parameterized tests cover 2D/3D and varied configurations. Docstring is clear.
166-182: LGTM - confirms gradient flow during training.Test validates that checkpointed blocks support backward pass. While it doesn't explicitly verify
checkpoint()is invoked, gradient propagation is the key functional requirement.
|
@ericspod sorry, missed that one. All addressed now. |
|
/build |
Description
Introduces an optional
use_checkpointingflag in theUNetimplementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory._ActivationCheckpointWrapperwrapper around sub-blocks.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.